{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lu6HYEefHRKE"
      },
      "outputs": [],
      "source": [
        "# Use the %pip magic so the package is installed into the running kernel's environment.\n",
        "%pip install langfun --pre"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p97Z_aAAHht-"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import random\n",
        "import re\n",
        "from typing import Literal, Annotated\n",
        "\n",
        "import langfun as lf\n",
        "import pandas as pd\n",
        "import pyglove as pg"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lDiFrOJjIigu"
      },
      "outputs": [],
      "source": [
        "# Prefer the OPENAI_API_KEY environment variable so a real key is never saved\n",
        "# in the notebook; otherwise fall back to the placeholder, to be edited by hand.\n",
        "openai_key = os.environ.get(\"OPENAI_API_KEY\", \"\u003cyour OpenAI key\u003e\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mRWwhgNCIko8"
      },
      "outputs": [],
      "source": [
        "# One model handle per role. The generator and the verifier sample with\n",
        "# temperature 0.3; the validator and the answer writer use temperature 0.0.\n",
        "\n",
        "# Proposes the next step.\n",
        "generate_lm = lf.llms.Gpt4_0613(\n",
        "    temperature=0.3,\n",
        "    max_tokens=512,\n",
        "    timeout=100,\n",
        "    stop=[\"```\\n\"],\n",
        "    api_key=openai_key,\n",
        ")\n",
        "# Validates a step.\n",
        "valid_lm = lf.llms.Gpt4_0613(\n",
        "    temperature=0.0,\n",
        "    max_tokens=512,\n",
        "    timeout=100,\n",
        "    stop=[\"```\\n\"],\n",
        "    api_key=openai_key,\n",
        ")\n",
        "# Drafts expressions over the remaining numbers.\n",
        "verifier_lm = lf.llms.Gpt4_0613(\n",
        "    temperature=0.3,\n",
        "    max_tokens=512,\n",
        "    timeout=100,\n",
        "    stop=[\"```\\n\"],\n",
        "    api_key=openai_key,\n",
        ")\n",
        "# Writes the final expression.\n",
        "answer_lm = lf.llms.Gpt4_0613(\n",
        "    temperature=0.0,\n",
        "    max_tokens=512,\n",
        "    timeout=100,\n",
        "    stop=[\"```\\n\"],\n",
        "    api_key=openai_key,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6xeSoRe7IrzP"
      },
      "source": [
        "# Game24 Solver"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mp-baVydJJeu"
      },
      "source": [
        "## Schemas (41 lines of code)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zlipYXdwIiak"
      },
      "outputs": [],
      "source": [
        "# One arithmetic step: the expression applied and the numbers left afterwards.\n",
        "class Step(pg.Object):\n",
        "  expression: str\n",
        "  left_numbers: list[int]\n",
        "\n",
        "\n",
        "# Input of the next-step generation query: the available numbers and the\n",
        "# expressions that were already proposed.\n",
        "class GenerateInput(pg.Object):\n",
        "  input: list[int]\n",
        "  expression_blacklist: Annotated[list[str], \"Expressions that should be skipped.\"]\n",
        "\n",
        "# Output of the next-step generation query.\n",
        "class Generate(pg.Object):\n",
        "  expression: Annotated[str, \"Expression with the two input numbers and basic arithmetic operations (+ - * /). Calculate the answer correctly. The Expression should be different from the blacklist.\"]\n",
        "  numbers_from_input: list[int]\n",
        "  remain_numbers_in_input: list[int]\n",
        "  result_of_expression: int | float\n",
        "\n",
        "# Input of the validation query: the numbers available before the step and the\n",
        "# step to check.\n",
        "class ValidationInput(pg.Object):\n",
        "  input: list[int]\n",
        "  step: Step\n",
        "\n",
        "# Output of the validation query: the individual checks followed by the verdict.\n",
        "class Validation(pg.Object):\n",
        "  expression: str\n",
        "  input_check: Annotated[list[str], \"Check whether the numbers in the expression including repetitions are a part of, or the complete list of, the given input.\"]\n",
        "  math_correctness_check: str\n",
        "  left_numbers_check: Annotated[str, \"Check whether the left numbers of step are the numbers that are not included in the expression, combined with the result of the expression.\"]\n",
        "  judgement: Annotated[Literal[\"Valid\", \"Invalid\"], \"Conclude whether the step pass all validations: math check, input check and the verification of left numbers.\"]\n",
        "\n",
        "# Input of the verifier query: the numbers left after a step and the drafts to\n",
        "# skip.\n",
        "class VerifierInput(pg.Object):\n",
        "  input: list[int]\n",
        "  expression_blacklist: Annotated[list[str], \"Expressions that should be skipped in the drafts.\"]\n",
        "\n",
        "# Output of the verifier query: draft expressions and, if any, the one that\n",
        "# obtains 24.\n",
        "class Verifier(pg.Object):\n",
        "  expression_drafts: Annotated[list[str], \"Expressions with all input numbers and basic arithmetic operations (+ - * /). Calculate the answer step by step. List ~15 unique drafts.\"]\n",
        "  expression_that_obtains_24: str | None\n",
        "\n",
        "# Input of the final-answer query: the puzzle numbers and the step expressions\n",
        "# passed as hints.\n",
        "class Game24SolutionInput(pg.Object):\n",
        "  numbers_in_final_expression: list[int]\n",
        "  hints: list[str]\n",
        "\n",
        "# Output of the final-answer query: the reasoning steps and the final expression.\n",
        "class Game24Solution(pg.Object):\n",
        "  figure_out_solution_step_by_step: list[str]\n",
        "  final_expression: str | None"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bvNL6zAXJPEh"
      },
      "source": [
        "## One-shot examples (treated as data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9pzMsgdxImPy"
      },
      "outputs": [],
      "source": [
        "# One-shot example for the next-step generation query (GenerateInput to Generate).\n",
        "generate_few_shots = [\n",
        "    lf.structured.mapping.MappingExample(\n",
        "        input=GenerateInput(\n",
        "            input=[1, 4, 8, 11], expression_blacklist=[\"1 + 4 = 5\"]\n",
        "        ),\n",
        "        output=Generate(\n",
        "            expression=\"1 + 11 = 12\",\n",
        "            numbers_from_input=[1, 11],\n",
        "            remain_numbers_in_input=[4, 8],\n",
        "            result_of_expression=12,\n",
        "        ),\n",
        "    ),\n",
        "]\n",
        "\n",
        "# One-shot example for the validation query: a step whose arithmetic is wrong\n",
        "# (13 / 4 taken as 3) and which is therefore judged Invalid.\n",
        "valid_few_shots = [\n",
        "    lf.structured.mapping.MappingExample(\n",
        "        input=ValidationInput(\n",
        "            input=[4, 7, 13],\n",
        "            step=Step(\n",
        "                expression=\"7 * (13 / 4) = 7 * 3 = 21\",\n",
        "                left_numbers=[21],\n",
        "            ),\n",
        "        ),\n",
        "        output=Validation(\n",
        "            expression=\"7 * (13 / 4)\",\n",
        "            input_check=[\n",
        "                \"1. Extract the numbers from the expression by removing all operators and parentheses from '7 * (13 / 4)'. We get all numbers of the expression are 7, 13, and 4.\",\n",
        "                \"2. The first thing is to check the extracted numbers' size is less than or equal to the input numbers' size. The extracted list [7, 13, 4] has 3 number while the input list [4, 7, 13] has 3 numbers. So the size check passes.\",\n",
        "                \"3. Then check if all numbers from extracted list [7, 13, 4] are in the input list [4, 7, 13]. First, [7, 13, 4] has 7 and [4, 7, 13] has 7. After removing 7 from both lists, we get [13, 4] and [4, 13]. Then, [13, 4] has 13 and [4, 13] has 13. After removing 13 from both lists, we get [4] and [4]. As 4 is in both lists, all numbers from extracted list [7, 13, 4] are in the input list [4, 7, 13]. So the check passes.\",\n",
        "            ],\n",
        "            math_correctness_check=(\n",
        "                \"7 * (13 / 4) = 7 * 3.25 = 22.75. So 21 is\"\n",
        "                \" incorrect result of the expression. So the check fails.\"\n",
        "            ),\n",
        "\n",
        "            left_numbers_check=(\n",
        "                \"As all input numbers [4, 7, 13] are included in the\"\n",
        "                \" expression, the only left numbers are the result of the\"\n",
        "                \" expression 22.75. So the correct left numbers are [22.75].\"\n",
        "                \" The left numbers [21] is incorrect. So the check fails.\"\n",
        "            ),\n",
        "            judgement=\"Invalid\",\n",
        "        ),\n",
        "    ),\n",
        "]\n",
        "\n",
        "# One-shot example for the verifier query: drafts over the remaining numbers,\n",
        "# one of which reaches 24.\n",
        "# NOTE(review): some blacklist entries below use the Unicode minus sign (U+2212)\n",
        "# instead of ASCII '-'; confirm this is intentional.\n",
        "verifier_few_shots = [\n",
        "    lf.structured.mapping.MappingExample(\n",
        "        input=VerifierInput(\n",
        "            input=[1, 2, 11],\n",
        "            expression_blacklist=[\n",
        "                \"(1 + 11) / 2 = 12 / 2 = 6\",\n",
        "                \"1 + 11 - 2 = 12 - 2 = 10\",\n",
        "                \"(1 + 2) * 11 = 3 * 11 = 33\",\n",
        "                \"(1 + 2) − 11 = 3 - 11 = -8\",\n",
        "                \"(1 + 2) / 11 = 3 / 11 = 0.2727\",\n",
        "                \"(1 − 11) * 2 = -10 * 2 = -20\",\n",
        "                \"(1 − 11) − 2 = -10 - 2 = -12\",\n",
        "            ],\n",
        "        ),\n",
        "        output=Verifier(\n",
        "            expression_drafts=[\n",
        "                \"1 * 2 + 11 = 2 + 11 = 13\",\n",
        "                \"1 + 2 + 11 = 3 + 11 = 14\",\n",
        "                \"2 * (11 - 1) = 2 * 10 = 20\",\n",
        "                \"11 * 2 - 1 = 22 - 1 = 21\",\n",
        "                \"2 / 1 * 11 = 2 * 11 = 22\",\n",
        "                \"(11 - 1) / 2 = 10 / 2 = 5\",\n",
        "                \"2 * (11 + 1) = 2 * 12 = 24\",\n",
        "                \"11 - 1 * 2 = 11 - 2 = 9\",\n",
        "                \"1 * (2 + 11) = 1 * 13 = 13\",\n",
        "                \"1 * (11 - 2) = 1 * 9 = 9\",\n",
        "                \"2 * 11 + 1 = 22 + 1 = 23\",\n",
        "                \"(1 * 2) / 11 = 2 / 11 = 0.1818\",\n",
        "                \"(1 - 11) / 2 = -10 / 2 = -5\",\n",
        "                \"11 - (1 + 2) = 11 - 3 = 8\",\n",
        "                \"(2 - 1) * 11 = 1 * 11 = 11\",\n",
        "            ],\n",
        "            expression_that_obtains_24=\"2 * (11 + 1) = 2 * 12 = 24\",\n",
        "        ),\n",
        "    ),\n",
        "]\n",
        "\n",
        "# One-shot example for the final-answer query: two step expressions given as\n",
        "# hints are combined into one expression over the four puzzle numbers.\n",
        "game24_solution_few_shots = [\n",
        "    lf.structured.mapping.MappingExample(\n",
        "        Game24SolutionInput(\n",
        "            numbers_in_final_expression=[4, 4, 8, 12],\n",
        "            hints=[\n",
        "                \"4 - 12 = -8\",\n",
        "                \"4 * 8 - 8 = 32 - 8 = 24\",\n",
        "            ],\n",
        "        ),\n",
        "        Game24Solution(\n",
        "            figure_out_solution_step_by_step=[\n",
        "                \"1. As hint 1 inludes two input numbers and hint 2 includes the other two input numbers plus the result from hint 1, the plan is to bring hint 1 into hint 2 to form the final expression with the 4 input numbers.\",\n",
        "                \"2. As hint 2 leads to the final answer, let's start from that. '4 * 8 - 8 = 32 - 8 = 24' suggests that we can use the numbers [4, 8, 8] to get 24.\",\n",
        "                \"3. Then, the hint 1 '4 - 12 = -8' suggests that we can replace a 8, which is a number doesn't come from input in hint 2 expression, with the expression '12 - 4', which are two input numbers.\",\n",
        "                \"4. Combining all the information, we can substitute 8 with '12 - 4' in the expression '4 * 8 - 8' to get the final expression '4 * (12 - 4) - 8', where 4 numbers in the expression are all from the input.\",\n",
        "            ],\n",
        "            final_expression=\"4 * (12 - 4) - 8\",\n",
        "        ),\n",
        "    ),\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7aaa-jiFI4fN"
      },
      "source": [
        "## GameOf24 Agent (42 lines of code)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "biaertr5I1Ke"
      },
      "outputs": [],
      "source": [
        "def game_24_solver(example, num_iter=50, num_verification=2):\n",
        "  \"\"\"Searches for an expression over the numbers in `example` that makes 24.\n",
        "\n",
        "  Each iteration asks the LLM for one step, validates it, and then asks a\n",
        "  verifier whether the numbers left after that step can reach 24.\n",
        "\n",
        "  Args:\n",
        "    example: The puzzle's input numbers.\n",
        "    num_iter: Maximum number of step proposals to try.\n",
        "    num_verification: Number of verifier queries per accepted step.\n",
        "\n",
        "  Returns:\n",
        "    A `Game24Solution` when one is produced, otherwise None.\n",
        "  \"\"\"\n",
        "  # Expressions already proposed; passed back so they are not proposed again.\n",
        "  next_step_blacklist=[]\n",
        "  for i in range(num_iter):\n",
        "    # Generate next step.\n",
        "    p = lf.query(GenerateInput(input=example, expression_blacklist=next_step_blacklist), Generate, lm=generate_lm, examples=generate_few_shots, default=None)\n",
        "    if p is None:\n",
        "      continue\n",
        "    next_step_blacklist.append(p.expression)\n",
        "    # Steps whose result is a float are not explored further.\n",
        "    if isinstance(p.result_of_expression, float):\n",
        "      continue\n",
        "    next_step = Step(expression=p.expression, left_numbers=p.remain_numbers_in_input + [p.result_of_expression])\n",
        "\n",
        "    # Validate the next step.\n",
        "    # NOTE(review): a failed validation query (None) lets the step through here,\n",
        "    # whereas the final step below requires an explicit \"Valid\" - confirm intended.\n",
        "    validation = lf.query(ValidationInput(input=example, step=next_step), Validation, lm=valid_lm, examples=valid_few_shots, default=None)\n",
        "    if validation and validation.judgement == \"Invalid\":\n",
        "      continue\n",
        "\n",
        "    # Drafts produced by earlier verifier rounds for this step.\n",
        "    expression_blacklist = []\n",
        "    for _ in range(num_verification):\n",
        "      # Pass a random subset (at most 15) of the earlier drafts as the blacklist.\n",
        "      random.shuffle(expression_blacklist)\n",
        "      num_blacklist = random.randint(0, min(15, len(expression_blacklist)))\n",
        "      trimed_blacklist = expression_blacklist[:num_blacklist]\n",
        "      # Verify how likely the next step is to lead to 24.\n",
        "      verification = lf.query(VerifierInput(input=next_step.left_numbers, expression_blacklist=trimed_blacklist), Verifier, lm=verifier_lm, examples=verifier_few_shots, default=None)\n",
        "      if verification is None:\n",
        "        continue\n",
        "\n",
        "      # If the next step can lead to 24, output the solution.\n",
        "      if verification.expression_that_obtains_24:\n",
        "        final_step = Step(expression=verification.expression_that_obtains_24, left_numbers=[24])\n",
        "        # Validate the final step.\n",
        "        validation = lf.query(ValidationInput(input=next_step.left_numbers, step=final_step), Validation, lm=valid_lm, examples=valid_few_shots, default=None)\n",
        "\n",
        "        if isinstance(validation, Validation) and validation.judgement == \"Valid\":\n",
        "          # Generate the final expression from steps.\n",
        "          ans = lf.query(Game24SolutionInput(numbers_in_final_expression=example, hints=[next_step.expression, final_step.expression]), Game24Solution, lm=answer_lm, examples=game24_solution_few_shots, default=None)\n",
        "          if isinstance(ans, Game24Solution):\n",
        "            return ans\n",
        "\n",
        "      # Remember this round's drafts so later rounds can be told to skip them.\n",
        "      expression_blacklist.extend(verification.expression_drafts)\n",
        "\n",
        "  return None"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hHq28WlxJAfs"
      },
      "source": [
        "# Evaluation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5a48VoCwJCLQ"
      },
      "source": [
        "## Load Eval Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LZBaU3yHI9dG"
      },
      "outputs": [],
      "source": [
        "# Puzzle list hosted in the tree-of-thought-llm repository.\n",
        "df = pd.read_csv('https://raw.githubusercontent.com/princeton-nlp/tree-of-thought-llm/master/src/tot/data/24/24.csv')\n",
        "# Keep rows 900-999; each puzzle is a space-separated string of integers.\n",
        "examples = [\n",
        "    [int(token) for token in puzzle.split(' ')]\n",
        "    for puzzle in df[900:1000]['Puzzles'].tolist()\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xma7FaxzJGzl"
      },
      "source": [
        "## Utils"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7sTS4SPmJKfQ"
      },
      "outputs": [],
      "source": [
        "def answer(output: pg.Dict, example) -\u003e int:\n",
        "  \"\"\"Evaluates the final expression proposed by the LLM agent.\n",
        "\n",
        "  Args:\n",
        "    output: Object exposing a `final_expression` attribute.\n",
        "    example: The numbers given for the puzzle.\n",
        "\n",
        "  Returns:\n",
        "    The value computed by `lf.coding.evaluate` for the expression, or 0 when\n",
        "    the expression is missing, does not use exactly the numbers in `example`,\n",
        "    or cannot be evaluated.\n",
        "  \"\"\"\n",
        "  if not output.final_expression:\n",
        "    return 0\n",
        "\n",
        "  # Sometimes the expression includes \" = 24\" which should be removed.\n",
        "  exp = output.final_expression.split('=')[0]\n",
        "\n",
        "  # Make sure the expression uses exactly the given numbers. Only unsigned digit\n",
        "  # runs are matched: a pattern that also accepts a leading '-' reads the binary\n",
        "  # minus in \"12-4\" as the operand -4 and rejects a correct expression.\n",
        "  # Puzzle numbers are assumed to be non-negative.\n",
        "  numbers = sorted(int(number) for number in re.findall(r'\\d+', exp))\n",
        "  if numbers != sorted(example):\n",
        "    return 0\n",
        "\n",
        "  try:\n",
        "    return lf.coding.evaluate(exp)\n",
        "  except Exception:\n",
        "    return 0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z_vdC_09JO-i"
      },
      "source": [
        "## Run the Eval"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o7QFnxHAJMz-"
      },
      "outputs": [],
      "source": [
        "# Runs the solver over all puzzles in parallel and reports running statistics.\n",
        "\n",
        "# Num of workers working in parallel.\n",
        "max_workers = 10\n",
        "\n",
        "success = 0\n",
        "total = 0\n",
        "parse_error = 0\n",
        "failed_to_solve = 0\n",
        "wrong_solution = 0\n",
        "failed_to_solve_list = []\n",
        "for e, ans, error in lf.concurrent_map(game_24_solver, examples, max_workers=max_workers, show_progress=True):\n",
        "  total += 1\n",
        "\n",
        "  # `error` used to be ignored and `parse_error` was never incremented. A job\n",
        "  # that reports an error has no usable answer, so count it here and move on.\n",
        "  if error is not None:\n",
        "    parse_error += 1\n",
        "    print(f'[ERROR] {e} {error}')\n",
        "    continue\n",
        "\n",
        "  if ans is None:\n",
        "    failed_to_solve += 1\n",
        "    print(f'[FAILED TO SOLVE] {e}')\n",
        "    failed_to_solve_list.append(e)\n",
        "    continue\n",
        "\n",
        "  # Compare with a small tolerance: expressions containing '/' evaluate to\n",
        "  # floats, which can differ from 24 by rounding error. The type check keeps\n",
        "  # a non-numeric evaluation result counted as a wrong solution.\n",
        "  value = answer(ans, e)\n",
        "  if isinstance(value, (int, float)) and abs(value - 24) \u003c 1e-6:\n",
        "    success += 1\n",
        "    print(f'[SUCCESS] {e} {ans.final_expression}')\n",
        "  else:\n",
        "    wrong_solution += 1\n",
        "    print(f'[FAILED] {e} {ans}')\n",
        "\n",
        "  acc = success / total\n",
        "  parse_error_rate = parse_error / total\n",
        "  failed_to_solve_rate = failed_to_solve / total\n",
        "  wrong_solution_rate = wrong_solution / total\n",
        "  print(f'acc: {acc}, total: {total}, parse_error_rate: {parse_error_rate}, failed_to_solve_rate: {failed_to_solve_rate}, wrong_solution_rate: {wrong_solution_rate}')\n",
        "\n",
        "# Final summary over all puzzles.\n",
        "acc = success / total\n",
        "parse_error_rate = parse_error / total\n",
        "failed_to_solve_rate = failed_to_solve / total\n",
        "wrong_solution_rate = wrong_solution / total\n",
        "print(f'acc: {acc}, total: {total}, parse_error_rate: {parse_error_rate}, failed_to_solve_rate: {failed_to_solve_rate}, wrong_solution_rate: {wrong_solution_rate}')\n",
        "\n",
        "print(f'failed_to_solve_list: {failed_to_solve_list}')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": [
        {
          "file_id": "1qeWhkhFbHlXs4jojIl9xowo3lvBJhr0T",
          "timestamp": 1716244600293
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
